# Directory that holds the input annotation files; the generated
# image_split.json is written here as well.
data_dir = "./experiments/med_pub_data/"
# File names (relative to data_dir) of the train / test annotation JSON files.
train_set_name = "med_pub_train.json"
test_set_name = "med_pub_test.json"

import json
import os
def _collect_first_images(json_path):
    """Return the set of first image entries found in a JSON annotation file.

    Args:
        json_path: Path to a UTF-8 JSON file whose top-level value is an
            iterable of records, each providing an ``"images"`` sequence.

    Returns:
        A set holding ``record["images"][0]`` for every record, so repeated
        images are counted once.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record has no ``"images"`` key.
        IndexError: If a record's ``"images"`` sequence is empty.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    return {record["images"][0] for record in records}


# Unique first-image entries per split.
split_data = {
    "train": _collect_first_images(os.path.join(data_dir, train_set_name)),
    "test": _collect_first_images(os.path.join(data_dir, test_set_name)),
}
print(f"train set size: {len(split_data['train'])}")
print(f"test set size: {len(split_data['test'])}")
# Convert the sets to lists, since JSON cannot serialize a set. They are
# sorted because set iteration order for strings can differ between
# interpreter runs (hash randomization), which would make the output file
# change on every run even for identical inputs.
# NOTE(review): assumes image entries are mutually comparable (e.g. all
# path strings) -- confirm against the dataset.
split_data["train"] = sorted(split_data["train"])
split_data["test"] = sorted(split_data["test"])
# Save the split as a JSON file next to the input data.
output_file = os.path.join(data_dir, "image_split.json")
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(split_data, f, ensure_ascii=False, indent=4)
print(f"Image split data saved to {output_file}")